KV Cache for flow DiT #1875
Open
hcwang26 wants to merge 1 commit into FunAudioLLM:main from
Conversation
nice pr
Collaborator
Thanks a lot! When we tried caching before, the problem was that as the audio duration grows, the GPU memory taken by the k/v cache increases significantly. Have you measured how many concurrent streams, say, 40 GB of GPU memory can support?
That should be positively correlated with the length of the audio being decoded concurrently. Could we exchange contact info?
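For context on that growth: the cache is linear in decoded frames, so a back-of-the-envelope estimate is easy to write down. Every model dimension below (layers, heads, head size, frame rate) is an assumed placeholder, not CosyVoice3's actual configuration; only the linear scaling is the point.

```python
# Rough per-stream KV-cache size. All dimensions are illustrative assumptions.
def kv_cache_bytes(seconds, n_layers=22, n_heads=16, head_dim=64,
                   frames_per_sec=50, dtype_bytes=2):
    t = seconds * frames_per_sec                 # cached sequence length in frames
    # K and V per layer: 2 * heads * t * head_dim elements each step
    return 2 * n_layers * n_heads * t * head_dim * dtype_bytes

for s in (10, 60, 300):
    print(f"{s:>4}s -> {kv_cache_bytes(s) / 2**20:.0f} MiB per stream")
```

Under these placeholder numbers, a 60-second stream holds roughly a quarter GiB of cache, so the sustainable concurrency on a 40 GB card depends directly on how long each concurrent utterance runs.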
Summary
Adds KV-cache / chunked streaming support for the flow DiT decoder (CosyVoice3), enabling incremental causal diffusion inference with reusable attention & conv caches. Fully backward-compatible with the existing CosyVoice2 UNet path.
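For orientation, here is a minimal sketch of how such a cached streaming decode might be driven, based only on the signatures quoted under Changes below. The return convention of inference_chunk, the cache initialization, and the flow / token_chunks names are assumptions, not the PR's actual calling code.

```python
import torch

conv_cache, att_cache = None, None          # empty caches before the first chunk
token_offset, mel_chunks = 0, []

for tokens in token_chunks:                 # hypothetical iterator over speech-token chunks
    # Assumed to return the new mel chunk plus updated caches for the next call.
    mel, conv_cache, att_cache = flow.inference_chunk(
        token=tokens,
        token_offset=token_offset,
        conv_cache=conv_cache,              # causal-conv state, reused across chunks
        att_cache=att_cache,                # per-block past K/V; grows with audio length
        init_cache=(token_offset == 0),
        chunk_size=25,
        n_timesteps=10,
    )
    token_offset += tokens.shape[1]
    mel_chunks.append(mel)

mel = torch.cat(mel_chunks, dim=-1)         # assemble the full mel-spectrogram
```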
Changes
- cosyvoice/flow/DiT/modules.py: Custom einops-based apply_rotary_pos_emb with an offset argument for chunk-aware RoPE. CausalConvPositionEmbedding.forward now accepts and updates a conv_cache. AttnProcessor.__call__ accepts x_offset / att_cache, splits and concatenates past K/V, pads attn_mask as the cache grows, and returns new_att_cache (a minimal sketch of this cached attention step follows this list). New Attention.forward_chunk and DiTBlock.forward_chunk methods.
- cosyvoice/flow/DiT/dit.py: InputEmbedding.forward threads conv_cache through the causal conv. DiT.forward is unchanged in behavior (returns an (output, None) tuple for consistency). New DiT.forward_chunk(x, x_offset, mask, mu, t, spks, cond, ..., conv_cache=None, att_cache=None) performs offset-aware RoPE plus per-block forward_chunk, returning (output, new_conv_cache, stacked_new_att_cache).
- cosyvoice/flow/flow_matching.py: ConditionalCFM.forward_estimator normalized to always return a tuple. New solve_euler_chunk diffusion solver and forward_estimator_with_cache (handles both a torch.nn.Module DiT and a TRT-wrapped estimator with cache bindings: x_offset, conv_cache, and att_cache become extra TRT inputs/outputs). CausalConditionalCFM.forward kept byte-identical to upstream (UNet path untouched); the new CausalConditionalCFM.forward_chunk is the dedicated DiT/KV-cache entrypoint. compute_loss is tolerant of tuple returns from DiT estimators.
- cosyvoice/flow/flow.py: CausalMaskedDiffWithDiT.inference now calls self.decoder.forward_chunk(..., x_offset=0) for full inference. New CausalMaskedDiffWithDiT.inference_chunk(token, token_offset, ..., conv_cache, att_cache, ..., init_cache=False, chunk_size=25, n_timesteps=10) enables streaming decode with reusable caches and prompt-aware h_offset slicing.
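The cached attention step referenced above is easiest to see in isolation. Below is a self-contained sketch of the idea: RoPE applied at an absolute offset, past K/V prepended from the cache, and a causal mask padded to the grown key length. The function names mirror the PR's description, but the bodies are illustrative reconstructions (assuming interleaved RoPE pairing and a cache length equal to the offset), not the PR's actual code.

```python
import torch
import torch.nn.functional as F

def apply_rotary_pos_emb(x: torch.Tensor, freqs: torch.Tensor, offset: int = 0) -> torch.Tensor:
    """Apply interleaved RoPE to x (b, h, t, d), starting at absolute position `offset`."""
    f = freqs[offset:offset + x.size(2)]                    # (t, d // 2) precomputed angles
    cos, sin = f.cos()[None, None], f.sin()[None, None]     # broadcast over batch and heads
    x1, x2 = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return rotated.flatten(-2)

def cached_attention_step(q, k, v, att_cache, freqs, offset):
    """One chunked attention call: rotate Q/K at `offset`, prepend past K/V, return new cache."""
    q = apply_rotary_pos_emb(q, freqs, offset)
    k = apply_rotary_pos_emb(k, freqs, offset)              # past keys were rotated when cached
    if att_cache is not None:                               # att_cache: (2, b, h, t_past, d)
        k = torch.cat((att_cache[0], k), dim=2)
        v = torch.cat((att_cache[1], v), dim=2)
    new_att_cache = torch.stack((k, v))                     # handed back for the next chunk
    # Causal mask padded to the grown key length: query i (absolute position offset + i)
    # may attend to keys 0 .. offset + i.
    t_q, t_k = q.size(2), k.size(2)
    attn_mask = torch.ones(t_q, t_k, dtype=torch.bool, device=q.device).tril(t_k - t_q)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    return out, new_att_cache
```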
Performance

Compatibility
The DiT PyTorch module and the TRT-exported engine are handled uniformly: the TRT branch in forward_estimator_with_cache wires x_offset / conv_cache / att_cache as explicit TRT tensor bindings, so no hasattr/runtime-type dispatch is required at the CFM level.
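A hedged sketch of that dispatch shape follows. The TRT wrapper and its run_with_cache method are hypothetical stand-ins for the PR's actual binding code; the point is only that both branches present the same tuple-returning interface to the CFM caller.

```python
import torch

def forward_estimator_with_cache(estimator, x, x_offset, mask, mu, t, spks, cond,
                                 conv_cache=None, att_cache=None):
    """Uniform cached-estimator call; both branches return (output, conv_cache, att_cache)."""
    if isinstance(estimator, torch.nn.Module):
        # PyTorch path: DiT.forward_chunk already returns the three-tuple.
        return estimator.forward_chunk(x, x_offset, mask, mu, t, spks, cond,
                                       conv_cache=conv_cache, att_cache=att_cache)
    # TRT path (hypothetical wrapper): x_offset / conv_cache / att_cache are wired
    # as extra engine tensors, so the caller never needs hasattr-style probing.
    return estimator.run_with_cache(x, x_offset, mask, mu, t, spks, cond,
                                    conv_cache, att_cache)
```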
Notes
4 files touched, +434 / −42 lines total.